{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-09T13:28:52.792054700Z",
     "start_time": "2023-05-09T13:28:52.687059100Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "from common.functions import softmax\n",
    "from rnnlm import Rnnlm\n",
    "from better_rnnlm import BetterRnnlm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "class RnnlmGen(Rnnlm):\n",
    "    \"\"\"Rnnlm subclass that adds sampling of word-ID sequences.\"\"\"\n",
    "\n",
    "    def generate(self, start_id, skip_ids=None, sample_szie=100):\n",
    "        \"\"\"Sample word IDs one at a time, starting from start_id.\n",
    "\n",
    "        Args:\n",
    "            start_id: ID of the first word; it is always the first element\n",
    "                of the returned list.\n",
    "            skip_ids: Optional container of word IDs that must not appear in\n",
    "                the output. A sampled ID found in it is discarded and the\n",
    "                previous input is fed to predict() again.\n",
    "            sample_szie: Target length of the returned list, including\n",
    "                start_id. The misspelt name is kept so that existing keyword\n",
    "                callers keep working.\n",
    "\n",
    "        Returns:\n",
    "            list: start_id followed by the sampled word IDs as Python ints.\n",
    "        \"\"\"\n",
    "        word_ids = [start_id]\n",
    "        x = start_id\n",
    "        while len(word_ids) < sample_szie:\n",
    "            # predict() is given a single word ID reshaped to (1, 1).\n",
    "            x = np.array(x).reshape(1, 1)\n",
    "            score = self.predict(x)\n",
    "            p = softmax(score.flatten())\n",
    "\n",
    "            # choice(..., size=1) returns a 1-element array. Pull the scalar out\n",
    "            # explicitly: int() on an array with ndim > 0 is deprecated since\n",
    "            # NumPy 1.25, and a plain int also makes the membership test below\n",
    "            # work for any container type (list, set, tuple, ...).\n",
    "            sampled = int(np.random.choice(len(p), size=1, p=p)[0])\n",
    "\n",
    "            if (skip_ids is None) or (sampled not in skip_ids):\n",
    "                x = sampled\n",
    "                word_ids.append(sampled)\n",
    "\n",
    "        return word_ids"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T13:39:33.476853500Z",
     "start_time": "2023-05-09T13:39:33.470854100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
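    "# Generate text with a completely untrained model (i.e. with randomly initialised weights).\n",
    "# NOTE(review): a later cell calls model.load_params('./Rnnlm.pkl'), so the sample printed below may come from trained weights -- confirm."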
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "from dataset import ptb"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T13:40:03.815616400Z",
     "start_time": "2023-05-09T13:40:03.802617400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# Load the 'train' split through ptb.load_data, together with the word/ID lookup tables.\n",
    "corpus, word_to_id, id_to_word = ptb.load_data('train')\n",
    "corpus_size, vocab_size = len(corpus), len(word_to_id)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T14:25:50.844098500Z",
     "start_time": "2023-05-09T14:25:50.827098600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# Build the generator model and restore its parameters from a local file.\n",
    "# NOTE(review): load_params is defined outside this notebook; if it unpickles the\n",
    "# file (the name ends in .pkl), only load files that come from a trusted source.\n",
    "model = RnnlmGen()\n",
    "model.load_params('./Rnnlm.pkl')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T14:29:21.602266100Z",
     "start_time": "2023-05-09T14:29:21.533266Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# Choose the start word and the words that must be skipped during sampling,\n",
    "# then map both to their vocabulary IDs.\n",
    "start_word = 'you'\n",
    "skip_words = ['N', '<unk>', '$']\n",
    "start_id = word_to_id[start_word]\n",
    "skip_ids = [word_to_id[token] for token in skip_words]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T14:29:40.660494Z",
     "start_time": "2023-05-09T14:29:40.647495600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "you 're impressive.\n",
      " but mr. engelken says he will tell you even to offer that he was strongly discouraging whether competitors would n't provide advertising its portable control and will reduce innovation.\n",
      " last year 's level of volume will hold about two parties for most of whom are making market agreements ventures falling.\n",
      " reward.\n",
      " which can appear to leave japanese mr. rican glass out of educators and take talks with an geographic export goal of investments.\n",
      " the results are provided for japanese accounts for wide credits and costs to visit in the florida making decliners\n"
     ]
    }
   ],
   "source": [
    "# Sample a sequence of word IDs, turn it back into words, and print it with\n",
    "# each ' <eos>' marker rendered as a full stop followed by a line break.\n",
    "word_ids = model.generate(start_id, skip_ids)\n",
    "txt = ' '.join(id_to_word[i] for i in word_ids).replace(' <eos>', '.\\n')\n",
    "print(txt)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T14:30:03.493158100Z",
     "start_time": "2023-05-09T14:30:03.292383300Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
